# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention-based sequence-to-sequence model with dynamic RNN support."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from . import model
from . import model_helper

__all__ = ["AttentionModel"]


class AttentionModel(model.Model):
    """Sequence-to-sequence dynamic model with attention.

    This class implements a multi-layer recurrent neural network as encoder,
    and an attention-based decoder. This is the same as the model described in
    (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf.
    This class also allows to use GRU cells in addition to LSTM cells with
    support for dropout.
    """

    def __init__(self,
                 hparams,
                 mode,
                 iterator,
                 source_vocab_table,
                 target_vocab_table,
                 reverse_target_vocab_table=None,
                 scope=None,
                 extra_args=None):
        self.has_attention = hparams.attention_architecture and hparams.attention

        # Set attention_mechanism_fn
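        # Any override must match create_attention_mechanism's signature:
        # (attention_option, num_units, memory, source_sequence_length, mode)
        # and return a tf.contrib.seq2seq.AttentionMechanism.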
        if self.has_attention:
            if extra_args and extra_args.attention_mechanism_fn:
                self.attention_mechanism_fn = extra_args.attention_mechanism_fn
            else:
                self.attention_mechanism_fn = create_attention_mechanism

        super(AttentionModel, self).__init__(
            hparams=hparams,
            mode=mode,
            iterator=iterator,
            source_vocab_table=source_vocab_table,
            target_vocab_table=target_vocab_table,
            reverse_target_vocab_table=reverse_target_vocab_table,
            scope=scope,
            extra_args=extra_args)

    def _prepare_beam_search_decoder_inputs(
            self, beam_width, memory, source_sequence_length, encoder_state):
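        # tile_batch repeats each batch entry beam_width times along the batch
        # axis, e.g. [a, b] with beam_width=3 becomes [a, a, a, b, b, b], so
        # every beam hypothesis attends over its own copy of the memory.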
        memory = tf.contrib.seq2seq.tile_batch(
            memory, multiplier=beam_width)
        source_sequence_length = tf.contrib.seq2seq.tile_batch(
            source_sequence_length, multiplier=beam_width)
        encoder_state = tf.contrib.seq2seq.tile_batch(
            encoder_state, multiplier=beam_width)
        batch_size = self.batch_size * beam_width
        return memory, source_sequence_length, encoder_state, batch_size

    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build a RNN cell with attention mechanism that can be used by decoder."""
        # No attention: fall back to the base model's decoder cell.
        if not self.has_attention:
            return super(AttentionModel, self)._build_decoder_cell(
                hparams, encoder_outputs, encoder_state, source_sequence_length)
        elif hparams.attention_architecture != "standard":
            raise ValueError(
                "Unknown attention architecture %s" % hparams.attention_architecture)

        num_units = hparams.num_units
        num_layers = self.num_decoder_layers
        num_residual_layers = self.num_decoder_residual_layers
        infer_mode = hparams.infer_mode

        dtype = tf.float32

        # Attention mechanisms expect batch-major memory: [batch, time, depth].
        if self.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])
        else:
            memory = encoder_outputs

        if (self.mode == tf.contrib.learn.ModeKeys.INFER and
                infer_mode == "beam_search"):
            memory, source_sequence_length, encoder_state, batch_size = (
                self._prepare_beam_search_decoder_inputs(
                    hparams.beam_width, memory, source_sequence_length,
                    encoder_state))
        else:
            batch_size = self.batch_size

        # Create the attention mechanism over the encoder memory.
        attention_mechanism = self.attention_mechanism_fn(
            hparams.attention, num_units, memory, source_sequence_length, self.mode)

        cell = model_helper.create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=num_units,
            num_layers=num_layers,
            num_residual_layers=num_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=self.num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn)

        # Only generate alignment in greedy INFER mode.
        alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
                             infer_mode != "beam_search")
        cell = tf.contrib.seq2seq.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=num_units,
            alignment_history=alignment_history,
            output_attention=hparams.output_attention,
            name="attention")

        # TODO(thangluong): do we need num_layers, num_gpus?
        cell = tf.contrib.rnn.DeviceWrapper(cell,
                                            model_helper.get_device_str(
                                                num_layers - 1, self.num_gpus))

        if hparams.pass_hidden_state:
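            # zero_state() yields an AttentionWrapperState; clone(cell_state=...)
            # swaps in the encoder's final state while the attention bookkeeping
            # (time, alignments, alignment_history) starts fresh.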
            decoder_initial_state = cell.zero_state(batch_size, dtype).clone(
                cell_state=encoder_state)
        else:
            decoder_initial_state = cell.zero_state(batch_size, dtype)

        return cell, decoder_initial_state

    def _get_infer_summary(self, hparams):
        if not self.has_attention or hparams.infer_mode == "beam_search":
            return tf.no_op()
        return _create_attention_images_summary(self.final_context_state)


def create_attention_mechanism(attention_option, num_units, memory,
                               source_sequence_length, mode):
    """Create attention mechanism based on the attention_option."""
    del mode  # unused

    # Mechanism
    if attention_option == "luong":
        attention_mechanism = tf.contrib.seq2seq.LuongAttention(
            num_units, memory, memory_sequence_length=source_sequence_length)
    elif attention_option == "scaled_luong":
        attention_mechanism = tf.contrib.seq2seq.LuongAttention(
            num_units,
            memory,
            memory_sequence_length=source_sequence_length,
            scale=True)
    elif attention_option == "bahdanau":
        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            num_units, memory, memory_sequence_length=source_sequence_length)
    elif attention_option == "normed_bahdanau":
        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            num_units,
            memory,
            memory_sequence_length=source_sequence_length,
            normalize=True)
    else:
        raise ValueError("Unknown attention option %s" % attention_option)

    return attention_mechanism
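

# A minimal usage sketch (hedged; `encoder_outputs`, `decoder_cell`,
# `num_units`, and `source_sequence_length` are assumed to exist with the
# usual shapes, e.g. memory of [batch, src_len, num_units]):
#
#     mechanism = create_attention_mechanism(
#         "scaled_luong", num_units, encoder_outputs, source_sequence_length,
#         mode=None)  # mode is unused by this factory
#     cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, mechanism)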


def _create_attention_images_summary(final_context_state):
    """create attention image and attention summary."""
    attention_images = final_context_state.alignment_history.stack()
    # Reshape to (batch, src_seq_len, tgt_seq_len, 1).
    attention_images = tf.expand_dims(
        tf.transpose(attention_images, [1, 2, 0]), -1)
    # Scale to range [0, 255]
    attention_images *= 255
    attention_summary = tf.summary.image("attention_images", attention_images)
    return attention_summary
